{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Note\n",
    "\n",
    "Please view the [README](https://github.com/deeplearning4j/dl4j-examples/tree/overhaul_tutorials/tutorials/README.md) to learn about installing, setting up dependencies, and importing notebooks in Zeppelin"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Background\n",
    "\n",
    "Sometimes deep learning is just one piece of a larger project. You may have a time series problem that calls for analysis beyond a single neural network. Trajectory clustering is a difficult problem when your data isn't evenly sampled. The marine [Automatic Identification System (AIS)](https://en.wikipedia.org/wiki/Automatic_identification_system) is an open system for broadcasting ship positions. It primarily supports collision avoidance and lets maritime authorities monitor marine traffic.\n",
    "\n",
    "What if you wanted to determine the most popular routes? Or take it one step further and identify anomalous traffic? Not everything can be done with a single neural network. Furthermore, AIS data for 1 year is over 100GB compressed. You'll need more than just a desktop computer to analyze it seriously.\n",
    "\n",
    "#### Sequence-to-sequence Autoencoders\n",
    "\n",
    "As you learned in the *Basic Autoencoder* tutorial, applications of autoencoders in data science include dimensionality reduction and data denoising. Instead of dense layers, an autoencoder can be built from LSTM layers. The resulting network is a sequence-to-sequence autoencoder, which is effective at capturing temporal structure.\n",
    "\n",
    "In the case of AIS data, coordinates can be reported at irregular intervals over time, and not all time series for a single ship have the same length, so the data is both ragged and high-dimensional. Before deep learning, techniques like [dynamic time warping](https://en.wikipedia.org/wiki/Dynamic_time_warping) were used to measure similarity between such sequences. Now that we can train a network to compress a **trajectory** of a ship with a seq2seq autoencoder, we can use the compressed representation for downstream tasks such as clustering.\n",
    "#### Introducing G-means clustering\n",
    "\n",
    "So let's say we want to group similar trajectories of ships together using all available AIS data. It's hard to guess how many unique groups of routes exist in marine traffic, so an algorithm like k-means, which requires choosing the number of clusters up front, is a poor fit. This is where the G-means algorithm has some utility.\n",
    "\n",
    "G-means repeatedly tests each cluster's points for a Gaussian distribution. If a cluster fails the test, it is split in two. Splitting continues until every cluster appears Gaussian. There are other alternatives to k-means, but G-means is quite useful for our needs.\n",
    "\n",
    "#### Apache Spark\n",
    "\n",
    "Sometimes a single computer doesn't cut it for munging your data. [Hadoop](http://hadoop.apache.org/) was originally developed for storing and processing large amounts of data; with time came innovation, and [Apache Spark](http://spark.apache.org/) was developed for faster large-scale data processing, touting up to a 100x improvement over Hadoop. The two frameworks aren't entirely interchangeable - Spark doesn't have its own filesystem and often uses Hadoop's HDFS.\n",
    "\n",
    "Spark is also capable of SQL-like exploration of data with its Spark SQL module. It is not unique in the ecosystem: frameworks such as [Hive](https://hive.apache.org/) and [Pig](https://pig.apache.org/) offer similar functionality, and at a conceptual level both make it easy to write map-reduce programs. However, Spark has largely become the de facto standard for data analysis, and Pig has recently introduced a Spark integration."
   ]
  },
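  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the dynamic time warping mention above concrete, here is a minimal, self-contained Scala sketch of the classic dynamic-programming DTW distance between two coordinate trajectories of different lengths. The object and helper names are purely illustrative; nothing below comes from DataVec or Deeplearning4j.\n",
    "\n",
    "```scala\n",
    "// Illustrative DTW distance between two 2-D trajectories.\n",
    "object DtwSketch {\n",
    "  type Point = (Double, Double)\n",
    "\n",
    "  // Euclidean distance between two coordinate pairs (an illustrative choice)\n",
    "  private def dist(a: Point, b: Point): Double =\n",
    "    math.hypot(a._1 - b._1, a._2 - b._2)\n",
    "\n",
    "  // Classic O(n*m) dynamic-programming DTW\n",
    "  def dtw(s: Seq[Point], t: Seq[Point]): Double = {\n",
    "    val n = s.length\n",
    "    val m = t.length\n",
    "    val cost = Array.fill(n + 1, m + 1)(Double.PositiveInfinity)\n",
    "    cost(0)(0) = 0.0\n",
    "    for (i <- 1 to n; j <- 1 to m) {\n",
    "      val d = dist(s(i - 1), t(j - 1))\n",
    "      // each step may advance either sequence or both\n",
    "      cost(i)(j) = d + math.min(cost(i - 1)(j), math.min(cost(i)(j - 1), cost(i - 1)(j - 1)))\n",
    "    }\n",
    "    cost(n)(m)\n",
    "  }\n",
    "}\n",
    "```\n",
    "\n",
    "Two sequences tracing the same path at different sampling rates have a DTW distance of zero even though their lengths differ, which is exactly the property that made DTW attractive for unevenly sampled tracks."
   ]
  },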
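  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The G-means split test described above can be sketched in plain Scala. The snippet below is a hedged illustration based on the Hamerly & Elkan (2003) G-means paper: it standardizes a cluster's points projected onto one dimension and applies an Anderson-Darling normality test to decide whether the cluster should be split. The object name, the erf approximation, and the critical value are our own illustrative choices, not part of any library used in this tutorial.\n",
    "\n",
    "```scala\n",
    "// Illustrative G-means Gaussianity check (1-D projection + Anderson-Darling).\n",
    "object GMeansSketch {\n",
    "  // Abramowitz-Stegun approximation of the error function\n",
    "  private def erf(x: Double): Double = {\n",
    "    val t = 1.0 / (1.0 + 0.3275911 * math.abs(x))\n",
    "    val poly = (((((1.061405429 * t - 1.453152027) * t) + 1.421413741) * t - 0.284496736) * t + 0.254829592) * t\n",
    "    val y = 1.0 - poly * math.exp(-x * x)\n",
    "    if (x >= 0) y else -y\n",
    "  }\n",
    "\n",
    "  // Standard normal CDF\n",
    "  private def phi(x: Double): Double = 0.5 * (1.0 + erf(x / math.sqrt(2.0)))\n",
    "\n",
    "  // Anderson-Darling statistic of standardized, sorted data against N(0,1)\n",
    "  private def andersonDarling(z: Array[Double]): Double = {\n",
    "    val n = z.length\n",
    "    var s = 0.0\n",
    "    for (i <- 0 until n) {\n",
    "      s += (2 * i + 1) * (math.log(phi(z(i))) + math.log(1.0 - phi(z(n - 1 - i))))\n",
    "    }\n",
    "    -n - s / n\n",
    "  }\n",
    "\n",
    "  // True if the projections look Gaussian; if false, G-means splits the cluster.\n",
    "  // 1.8692 is the critical value used in the G-means paper.\n",
    "  def looksGaussian(projections: Array[Double]): Boolean = {\n",
    "    val n = projections.length\n",
    "    val mean = projections.sum / n\n",
    "    val std = math.sqrt(projections.map(v => (v - mean) * (v - mean)).sum / n)\n",
    "    val z = projections.map(v => (v - mean) / std).sorted\n",
    "    val a2 = andersonDarling(z)\n",
    "    val a2Star = a2 * (1.0 + 4.0 / n - 25.0 / (n.toDouble * n)) // finite-sample correction\n",
    "    a2Star <= 1.8692\n",
    "  }\n",
    "}\n",
    "```\n",
    "\n",
    "In the full algorithm, the projection axis is the vector connecting the two candidate child centers found by running 2-means inside the cluster; here we only show the statistical test itself."
   ]
  },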
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### What are we going to learn in this tutorial?\n",
    "\n",
    "Using Deeplearning4j, DataVec, and some custom code, you will learn how to cluster large amounts of AIS data. We will use the local Spark cluster built into Zeppelin to execute DataVec preprocessing, train an autoencoder on the converted sequences, and finally run G-means on the compressed output and visualize the groups."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "import org.deeplearning4j.nn.graph.ComputationGraph\n",
    "import org.deeplearning4j.nn.transferlearning.TransferLearning\n",
    "import org.deeplearning4j.nn.api.OptimizationAlgorithm\n",
    "import org.deeplearning4j.nn.weights.WeightInit\n",
    "import org.deeplearning4j.nn.conf._\n",
    "import org.deeplearning4j.nn.conf.layers._\n",
    "import org.deeplearning4j.nn.conf.graph.rnn._\n",
    "import org.deeplearning4j.nn.conf.inputs.InputType\n",
    "import org.deeplearning4j.nn.conf.WorkspaceMode\n",
    "import org.deeplearning4j.optimize.listeners.ScoreIterationListener\n",
    "import org.deeplearning4j.datasets.iterator.MultipleEpochsIterator\n",
    "import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator\n",
    "import org.deeplearning4j.util.ModelSerializer\n",
    "\n",
    "import org.datavec.api.transform._\n",
    "import org.datavec.api.transform.transform.time.StringToTimeTransform\n",
    "import org.datavec.api.transform.sequence.comparator.NumericalColumnComparator\n",
    "import org.datavec.api.transform.transform.string.ConcatenateStringColumns\n",
    "import org.datavec.api.transform.transform.doubletransform.MinMaxNormalizer\n",
    "import org.datavec.api.transform.schema.Schema\n",
    "import org.datavec.api.transform.metadata.StringMetaData\n",
    "import org.datavec.api.records.reader.impl.csv.CSVRecordReader\n",
    "import org.datavec.api.split.FileSplit\n",
    "import org.datavec.spark.storage.SparkStorageUtils\n",
    "import org.datavec.spark.transform.misc.StringToWritablesFunction\n",
    "import org.datavec.spark.transform.SparkTransformExecutor\n",
    "import org.datavec.api.transform.condition._\n",
    "import org.datavec.api.transform.condition.column._\n",
    "import org.datavec.api.transform.sequence.window.ReduceSequenceByWindowTransform\n",
    "import org.datavec.api.transform.reduce.Reducer\n",
    "import org.datavec.api.transform.reduce.AggregableColumnReduction\n",
    "import org.datavec.api.transform.sequence.window.TimeWindowFunction\n",
    "import org.datavec.api.transform.ops.IAggregableReduceOp\n",
    "import org.datavec.api.transform.metadata.ColumnMetaData\n",
    "import org.datavec.api.writable._\n",
    "import org.datavec.hadoop.records.reader.mapfile.MapFileSequenceRecordReader\n",
    "import org.datavec.api.util.ArchiveUtils\n",
    "\n",
    "import org.nd4j.linalg.api.ndarray.INDArray\n",
    "import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor\n",
    "import org.nd4j.linalg.dataset.api.MultiDataSet\n",
    "import org.nd4j.linalg.lossfunctions.LossFunctions\n",
    "import org.nd4j.linalg.activations.Activation\n",
    "import org.nd4j.linalg.learning.config._\n",
    "import org.nd4j.linalg.factory.Nd4j\n",
    "import org.nd4j.linalg.indexing.BooleanIndexing\n",
    "import org.nd4j.linalg.indexing.INDArrayIndex\n",
    "import org.nd4j.linalg.indexing.NDArrayIndex._\n",
    "import org.nd4j.linalg.indexing.conditions.Conditions\n",
    "\n",
    "import org.apache.spark.api.java.function.Function\n",
    "import org.apache.commons.io.FileUtils\n",
    "import org.joda.time.DateTimeZone\n",
    "import org.joda.time.format.DateTimeFormat\n",
    "\n",
    "import scala.collection.JavaConversions._\n",
    "import scala.collection.JavaConverters._\n",
    "import scala.io.Source\n",
    "import java.util.Random\n",
    "import java.util.concurrent.TimeUnit\n",
    "import java.io._\n",
    "import java.net.URL\n",
    "\n",
    "val cache = new File(System.getProperty(\"user.home\"), \"/.deeplearning4j\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download the dataset\n",
    "\n",
    "The file we will be downloading is nearly 2GB uncompressed, so make sure you have enough space on your local disk. If you want to check out the file yourself, you can download a copy from [http://blob.deeplearning4j.org/datasets/aisdk_20171001.csv.zip](http://blob.deeplearning4j.org/datasets/aisdk_20171001.csv.zip). The code below will check whether the data already exists and download it if necessary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "val dataFile = new File(cache, \"/aisdk_20171001.csv\")\n",
    "\n",
    "if(!dataFile.exists()) {\n",
    "    val remote = \"http://blob.deeplearning4j.org/datasets/aisdk_20171001.csv.zip\"\n",
    "    val tmpZip = new File(cache, \"aisdk_20171001.csv.zip\")\n",
    "    tmpZip.delete() // prevents errors\n",
    "    println(\"Downloading file...\")\n",
    "    FileUtils.copyURLToFile(new URL(remote), tmpZip)\n",
    "    println(\"Decompressing file...\")\n",
    "    ArchiveUtils.unzipFileTo(tmpZip.getAbsolutePath(), cache.getAbsolutePath())\n",
    "    tmpZip.delete()\n",
    "    println(\"Done.\")\n",
    "} else {\n",
    "    println(\"File already exists.\")\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Examine sequence lengths\n",
    "\n",
    "The trouble with raw data is that it rarely has the clean structure you would expect from a curated example. It's useful to investigate the structure of the data, calculate some basic statistics such as average sequence length, and get a sense of its complexity.\n",
    "\n",
    "Below we count the length of each sequence and plot the distribution. *You will see that this is very problematic. The longest sequence in the data is 36,958 time steps!*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "val raw = sqlContext.read\n",
    "    .format(\"com.databricks.spark.csv\")\n",
    "    .option(\"header\", \"true\") // Use first line of all files as header\n",
    "    .option(\"inferSchema\", \"true\") // Automatically infer data types\n",
    "    .load(dataFile.getAbsolutePath)\n",
    "\n",
    "import org.apache.spark.sql.functions._\n",
    "\n",
    "val positions = raw\n",
    "    .withColumn(\"Timestamp\", unix_timestamp(raw(\"# Timestamp\"), \"dd/MM/yyyy HH:mm:ss\")) // lowercase yyyy = calendar year (uppercase YYYY is week-based)\n",
    "    .select(\"Timestamp\",\"MMSI\",\"Longitude\",\"Latitude\")\n",
    "    \n",
    "positions.printSchema    \n",
    "positions.registerTempTable(\"positions\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "val sequences = positions\n",
    "    .rdd\n",
    "    .map( row => (row.getInt(1), (row.getLong(0), (row.getDouble(3), row.getDouble(2)))) ) // a tuple of ship ID and timed coordinates\n",
    "    .groupBy(_._1)\n",
    "    .map( group => (group._1, group._2.map(pos => pos._2).toSeq.sortBy(_._1)))\n",
    "    \n",
    "case class Stats(numPositions: Int, minTime: Long, maxTime: Long, totalTime: Long)\n",
    "\n",
    "val stats = sequences\n",
    "    .map { seq => \n",
    "        val timestamps = seq._2.map(_._1).toArray\n",
    "        Stats(seq._2.size, timestamps.min, timestamps.max, (timestamps.max-timestamps.min))\n",
    "    }\n",
    "    .toDF()\n",
    "stats.registerTempTable(\"stats\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%sql\n",
    "\n",
    "select numPositions, count(1) value\n",
    "from stats \n",
    "where numPositions < 65\n",
    "group by numPositions\n",
    "order by numPositions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%sql\n",
    "\n",
    "select avg(totalTime), avg(minTime), avg(maxTime), avg(numPositions)\n",
    "from stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%sql\n",
    "\n",
    "select floor(Longitude/10)*10 as bin_floor, count(1) value\n",
    "from positions \n",
    "group by 1\n",
    "order by 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%sql\n",
    "\n",
    "select min(Latitude), max(Latitude), min(Longitude), max(Longitude)\n",
    "from positions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract and transform\n",
    "\n",
    "Now that we've examined our data, we need to extract it from the CSV and transform it into sequence data. DataVec and Spark make this straightforward.\n",
    "\n",
    "Using DataVec's `Schema` class we define the columns of the data. Alternatively, if you have a sample file of the data, you can use the `InferredSchema` class. Afterwards, we build a `TransformProcess` that removes unwanted fields and uses timestamp comparisons to create a sequence for each unique ship in the AIS data.\n",
    "\n",
    "Once we're certain that the schema and transformations are what we want, we can read the CSV into a [Spark RDD](http://spark.apache.org/docs/2.1.1/programming-guide.html#resilient-distributed-datasets-rdds) and execute our transformation with DataVec. First, we convert the data to sequences with `convertToSequence()` and a numerical comparator that sorts by timestamp. Then we apply a window function to each sequence and reduce each window to a single value. This reduces the variability in sequence lengths, which would otherwise be a problem when we train our autoencoder.\n",
    "\n",
    "If you prefer the Scala style of programming, you can switch back and forth between the Scala and Java APIs for the Spark RDD. Calling `.rdd` on a `JavaRDD` returns a regular Scala `RDD`; if you prefer the Java API, call `toJavaRDD()` on an `RDD`.\n",
    "\n",
    "#### Filtering of trajectories\n",
    "\n",
    "To reduce the complexity of this tutorial, we will omit anomalous trajectories. In the analysis above you'll see a significant number of trajectories with invalid positions: valid latitudes lie within [-90, 90] and longitudes within [-180, 180], so we filter out records beyond those ranges. Additionally, many of the trajectories include only a handful of positions, so we also eliminate sequences that are too short to be meaningful."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "// our reduction op class that we will need shortly\n",
    "// due to interpreter restrictions, we put this inside an object\n",
    "object Reductions extends Serializable {\n",
    "    class GeoAveragingReduction(val columnOutputName: String=\"AveragedLatLon\", val delim: String=\",\") extends AggregableColumnReduction {\n",
    "        \n",
    "        override def reduceOp(): IAggregableReduceOp[Writable, java.util.List[Writable]] = {\n",
    "            new AverageCoordinateReduceOp(delim)\n",
    "        }\n",
    "        \n",
    "        override def getColumnsOutputName(inputName: String): java.util.List[String] = List(columnOutputName)\n",
    "        \n",
    "        override def getColumnOutputMetaData(newColumnName: java.util.List[String], columnInputMeta: ColumnMetaData): java.util.List[ColumnMetaData] = \n",
    "            List(new StringMetaData(columnOutputName))\n",
    "        \n",
    "        override def transform(inputSchema: Schema) = inputSchema\n",
    "        \n",
    "        override def outputColumnName: String = null\n",
    "        \n",
    "        override def outputColumnNames: Array[String] = new Array[String](0)\n",
    "        \n",
    "        override def columnNames: Array[String] = new Array[String](0)\n",
    "        \n",
    "        override def columnName: String = null\n",
    "          \n",
    "        private var inputSchema: Schema = _\n",
    "        \n",
    "        def getInputSchema(): org.datavec.api.transform.schema.Schema = inputSchema\n",
    "        \n",
    "        def setInputSchema(x$1: org.datavec.api.transform.schema.Schema): Unit = { inputSchema = x$1 }\n",
    "    }\n",
    "    \n",
    "    class AverageCoordinateReduceOp(val delim: String) extends IAggregableReduceOp[Writable, java.util.List[Writable]] {\n",
    "        final val PI_180 = Math.PI / 180.0\n",
    "\n",
    "        var sumx = 0.0\n",
    "        var sumy = 0.0\n",
    "        var sumz = 0.0\n",
    "        var count = 0\n",
    "        \n",
    "        override def combine[W <: IAggregableReduceOp[Writable, java.util.List[Writable]]](accu: W): Unit = {\n",
    "          if (accu.isInstanceOf[AverageCoordinateReduceOp]) {\n",
    "            val r: AverageCoordinateReduceOp =\n",
    "              accu.asInstanceOf[AverageCoordinateReduceOp]\n",
    "            sumx += r.sumx\n",
    "            sumy += r.sumy\n",
    "            sumz += r.sumz\n",
    "            count += r.count\n",
    "          } else {\n",
    "            throw new IllegalStateException(\n",
    "              \"Cannot combine type of class: \" + accu.getClass)\n",
    "          }\n",
    "        }\n",
    "\n",
    "        override def accept(writable: Writable): Unit = {\n",
    "          val str: String = writable.toString\n",
    "          val split: Array[String] = str.split(delim)\n",
    "          if (split.length != 2) {\n",
    "            throw new IllegalStateException(\n",
    "              \"Could not parse lat/long string: \\\"\" + str + \"\\\"\")\n",
    "          }\n",
    "          val latDeg: Double = java.lang.Double.parseDouble(split(0))\n",
    "          val longDeg: Double = java.lang.Double.parseDouble(split(1))\n",
    "          val lat: Double = latDeg * PI_180\n",
    "          val lng: Double = longDeg * PI_180\n",
    "          val x: Double = Math.cos(lat) * Math.cos(lng)\n",
    "          val y: Double = Math.cos(lat) * Math.sin(lng)\n",
    "          val z: Double = Math.sin(lat)\n",
    "          sumx += x\n",
    "          sumy += y\n",
    "          sumz += z\n",
    "          count += 1\n",
    "        }\n",
    "\n",
    "        override def get(): java.util.List[Writable] = {\n",
    "          val x: Double = sumx / count\n",
    "          val y: Double = sumy / count\n",
    "          val z: Double = sumz / count\n",
    "          val longRad: Double = Math.atan2(y, x)\n",
    "          val hyp: Double = Math.sqrt(x * x + y * y)\n",
    "          val latRad: Double = Math.atan2(z, hyp)\n",
    "          val latDeg: Double = latRad / PI_180\n",
    "          val longDeg: Double = longRad / PI_180\n",
    "          val str: String = latDeg + delim + longDeg\n",
    "          List(new Text(str))\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "// note: the column names don't exactly match the CSV header; we assign our own\n",
    "val schema = new Schema.Builder()\n",
    "    .addColumnsString(\"Timestamp\")\n",
    "    .addColumnCategorical(\"VesselType\")\n",
    "    .addColumnsString(\"MMSI\")\n",
    "    .addColumnsDouble(\"Lat\",\"Lon\") // will convert to Double later\n",
    "    .addColumnCategorical(\"Status\")\n",
    "    .addColumnsDouble(\"ROT\",\"SOG\",\"COG\")\n",
    "    .addColumnInteger(\"Heading\")\n",
    "    .addColumnsString(\"IMO\",\"Callsign\",\"Name\")\n",
    "    .addColumnCategorical(\"ShipType\",\"CargoType\")\n",
    "    .addColumnsInteger(\"Width\",\"Length\")\n",
    "    .addColumnCategorical(\"FixingDevice\")\n",
    "    .addColumnDouble(\"Draught\")\n",
    "    .addColumnsString(\"Destination\",\"ETA\")\n",
    "    .addColumnCategorical(\"SourceType\")\n",
    "    .addColumnsString(\"end\")\n",
    "    .build()\n",
    "    \n",
    "val transform = new TransformProcess.Builder(schema)\n",
    "    .removeAllColumnsExceptFor(\"Timestamp\",\"MMSI\",\"Lat\",\"Lon\")\n",
    "    .filter(BooleanCondition.OR(new DoubleColumnCondition(\"Lat\",ConditionOp.GreaterThan,90.0), new DoubleColumnCondition(\"Lat\",ConditionOp.LessThan,-90.0))) // remove erroneous lat\n",
    "    .filter(BooleanCondition.OR(new DoubleColumnCondition(\"Lon\",ConditionOp.GreaterThan,180.0), new DoubleColumnCondition(\"Lon\",ConditionOp.LessThan,-180.0))) // remove erroneous lon\n",
    "    .transform(new MinMaxNormalizer(\"Lat\", -90.0, 90.0, 0.0, 1.0))\n",
    "    .transform(new MinMaxNormalizer(\"Lon\", -180.0, 180.0, 0.0, 1.0))\n",
    "    .convertToString(\"Lat\")\n",
    "    .convertToString(\"Lon\")\n",
    "    .transform(new StringToTimeTransform(\"Timestamp\",\"dd/MM/yyyy HH:mm:ss\",DateTimeZone.UTC))\n",
    "    .transform(new ConcatenateStringColumns(\"LatLon\", \",\", List(\"Lat\",\"Lon\")))\n",
    "    .convertToSequence(\"MMSI\", new NumericalColumnComparator(\"Timestamp\", true))\n",
    "    .transform(\n",
    "        new ReduceSequenceByWindowTransform(\n",
    "            new Reducer.Builder(ReduceOp.Count).keyColumns(\"MMSI\")\n",
    "                .countColumns(\"Timestamp\")\n",
    "                .customReduction(\"LatLon\", new Reductions.GeoAveragingReduction(\"LatLon\"))\n",
    "                .takeFirstColumns(\"Timestamp\")\n",
    "                .build(),\n",
    "            new TimeWindowFunction.Builder()\n",
    "                .timeColumn(\"Timestamp\")\n",
    "                .windowSize(1L, TimeUnit.HOURS)\n",
    "                .excludeEmptyWindows(true)\n",
    "                .build()\n",
    "        )\n",
    "    )\n",
    "    .removeAllColumnsExceptFor(\"LatLon\")\n",
    "    .build\n",
    "    \n",
    "// note we temporarily switch between java/scala APIs for convenience\n",
    "val rawData = sc\n",
    "    .textFile(dataFile.getAbsolutePath)\n",
    "    .filter(row => !row.startsWith(\"# Timestamp\")) // filter out the header  \n",
    "    .toJavaRDD // datavec API uses Spark's Java API\n",
    "    .map(new StringToWritablesFunction(new CSVRecordReader()))\n",
    "    \n",
    "// once transform is applied, filter sequences we consider \"too short\"\n",
    "// decombine lat/lon then convert to arrays and split, then convert back to java APIs\n",
    "val records = SparkTransformExecutor\n",
    "    .executeToSequence(rawData,transform)\n",
    "    .rdd\n",
    "    .filter( seq => seq.size() > 7)\n",
    "    .map{ row: java.util.List[java.util.List[Writable]] =>\n",
    "        row.map{ seq => seq.map(_.toString).map(_.split(\",\").toList.map(coord => new DoubleWritable(coord.toDouble).asInstanceOf[Writable])).flatten }\n",
    "    }\n",
    "    .map(_.toList.map(_.asJava).asJava)\n",
    "    .toJavaRDD\n",
    "\n",
    "val split = records.randomSplit(Array[Double](0.8,0.2))\n",
    "    \n",
    "val trainSequences = split(0)\n",
    "val testSequences = split(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Iteration and storage options\n",
    "\n",
    "Once you have finished preprocessing your dataset, you have a couple of options for serializing it before feeding it to your autoencoder network via an iterator. This applies to both unsupervised and supervised learning.\n",
    "\n",
    "1. **Save to a Hadoop Map File.** This serializes the dataset to the Hadoop map file format and writes it to disk. You can do this whether your network will train on a single local node or be distributed across multiple nodes via Spark. The advantage is that you can preprocess your dataset once and read from the map file as often as necessary.\n",
    "2. **Pass the `RDD` to a `RecordReaderMultiDataSetIterator`.** If you prefer to read your dataset directly from Spark, you can pass your RDD to a `RecordReaderMultiDataSetIterator`. The `SparkSourceDummyReader` class acts as a placeholder for each source of records. This process converts the records to a `MultiDataSet` that can be passed to a distributed neural network such as `SparkComputationGraph`.\n",
    "3. **Serialize to another format.** Other options, not discussed here, include saving the `INDArray` data in a compressed format on disk or using a custom format of your own.\n",
    "\n",
    "This example uses method 1 above. We'll assume you are training your network on a single machine. Note that you can always mix architectures for preprocessing and training (e.g., a Spark cluster for preprocessing, a GPU machine for training); it depends on what hardware you have available."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "// for purposes of this notebook, you only need to run this block once\n",
    "val trainFiles = new File(cache, \"/ais_trajectories_train/\")\n",
    "val testFiles = new File(cache, \"/ais_trajectories_test/\")\n",
    "\n",
    "// if you want to delete previously saved data\n",
    "// FileUtils.deleteDirectory(trainFiles)\n",
    "// FileUtils.deleteDirectory(testFiles)\n",
    "\n",
    "if(!trainFiles.exists()) SparkStorageUtils.saveMapFileSequences( trainFiles.getAbsolutePath(), trainSequences )\n",
    "if(!testFiles.exists()) SparkStorageUtils.saveMapFileSequences( testFiles.getAbsolutePath(), testSequences )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Iterating from disk\n",
    "\n",
    "Now that we've saved our dataset to a Hadoop map file, we need to set up a `RecordReader` and iterator to read the saved sequences and feed them to our autoencoder. Conveniently, once your data is on disk, you can run this code block (and the remaining code blocks) as many times as you want without preprocessing the dataset again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "// set up record readers that will read the features from disk\n",
    "val batchSize = 48\n",
    "\n",
    "// this preprocessor allows for insertion of GO/STOP tokens for the RNN\n",
    "object Preprocessor extends Serializable {\n",
    "    class Seq2SeqAutoencoderPreProcessor extends MultiDataSetPreProcessor {\n",
    "\n",
    "        override def preProcess(mds: MultiDataSet): Unit = {\n",
    "            val input: INDArray = mds.getFeatures(0)\n",
    "            val features: Array[INDArray] = Array.ofDim[INDArray](2)\n",
    "            val labels: Array[INDArray] = Array.ofDim[INDArray](1)\n",
    "            \n",
    "            features(0) = input\n",
    "            \n",
    "            val mb: Int = input.size(0)\n",
    "            val nClasses: Int = input.size(1)\n",
    "            val origMaxTsLength: Int = input.size(2)\n",
    "            val goStopTokenPos: Int = nClasses\n",
    "            \n",
    "            //1 new class, for GO/STOP. And one new time step for it also\n",
    "            val newShape: Array[Int] = Array(mb, nClasses + 1, origMaxTsLength + 1)\n",
    "            features(1) = Nd4j.create(newShape:_*)\n",
    "            labels(0) = Nd4j.create(newShape:_*)\n",
    "            //Create features. Append existing at time 1 to end. Put GO token at time 0\n",
    "            features(1).put(Array[INDArrayIndex](all(), interval(0, input.size(1)), interval(1, newShape(2))), input)\n",
    "            //Set GO token\n",
    "            features(1).get(all(), point(goStopTokenPos), all()).assign(1)\n",
    "            //Create labels. Append existing at time 0 to end-1. Put STOP token at last time step - **Accounting for variable length / masks**\n",
    "            labels(0).put(Array[INDArrayIndex](all(), interval(0, input.size(1)), interval(0, newShape(2) - 1)), input)\n",
    "            \n",
    "            var lastTimeStepPos: Array[Int] = null\n",
    "            \n",
    "            if (mds.getFeaturesMaskArray(0) == null) {//No masks\n",
    "                lastTimeStepPos = Array.ofDim[Int](input.size(0))\n",
    "                for (i <- 0 until lastTimeStepPos.length) {\n",
    "                  lastTimeStepPos(i) = input.size(2) - 1\n",
    "                }\n",
    "            } else {\n",
    "                val fm: INDArray = mds.getFeaturesMaskArray(0)\n",
    "                val lastIdx: INDArray = BooleanIndexing.lastIndex(fm, Conditions.notEquals(0), 1)\n",
    "                lastTimeStepPos = lastIdx.data().asInt()\n",
    "            }\n",
    "            for (i <- 0 until lastTimeStepPos.length) {\n",
    "                labels(0).putScalar(i, goStopTokenPos, lastTimeStepPos(i), 1.0)\n",
    "            }\n",
    "            //In practice: Just need to append an extra 1 at the start (as all existing time series are now 1 step longer)\n",
    "            var featureMasks: Array[INDArray] = null\n",
    "            var labelsMasks: Array[INDArray] = null\n",
    "            \n",
    "            if (mds.getFeaturesMaskArray(0) != null) {//Masks are present - variable length\n",
    "                featureMasks = Array.ofDim[INDArray](2)\n",
    "                featureMasks(0) = mds.getFeaturesMaskArray(0)\n",
    "                labelsMasks = Array.ofDim[INDArray](1)\n",
    "                val newMask: INDArray = Nd4j.hstack(Nd4j.ones(mb, 1), mds.getFeaturesMaskArray(0))\n",
    "                // println(mds.getFeaturesMaskArray(0).shape())\n",
    "                // println(newMask.shape())\n",
    "                featureMasks(1) = newMask\n",
    "                labelsMasks(0) = newMask\n",
    "            } else {\n",
    "                //All same length\n",
    "                featureMasks = null\n",
    "                labelsMasks = null\n",
    "            }\n",
    "            //Same for labels\n",
    "            mds.setFeatures(features)\n",
    "            mds.setLabels(labels)\n",
    "            mds.setFeaturesMaskArrays(featureMasks)\n",
    "            mds.setLabelsMaskArray(labelsMasks)\n",
    "        }\n",
    "        \n",
    "    }\n",
    "}\n",
    "\n",
    "// because this is an autoencoder, features = labels\n",
    "val trainRR = new MapFileSequenceRecordReader()\n",
    "trainRR.initialize(new FileSplit(trainFiles))\n",
    "val trainIter = new RecordReaderMultiDataSetIterator.Builder(batchSize)\n",
    "            .addSequenceReader(\"records\", trainRR)\n",
    "            .addInput(\"records\")\n",
    "            .build()\n",
    "trainIter.setPreProcessor(new Preprocessor.Seq2SeqAutoencoderPreProcessor)\n",
    "            \n",
    "val testRR = new MapFileSequenceRecordReader()\n",
    "testRR.initialize(new FileSplit(testFiles))\n",
    "val testIter = new RecordReaderMultiDataSetIterator.Builder(batchSize)\n",
    "            .addSequenceReader(\"records\", testRR)\n",
    "            .addInput(\"records\")\n",
    "            .build()\n",
    "testIter.setPreProcessor(new Preprocessor.Seq2SeqAutoencoderPreProcessor)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build the autoencoder\n",
    "\n",
    "Now that we've prepared our data, we can construct the sequence-to-sequence autoencoder. The configuration is similar to the autoencoders in other tutorials, except that the layers are primarily LSTMs. Note that this architecture uses a `DuplicateToTimeSeriesVertex` between the encoder and decoder: it copies the encoder's final output to every time step of the decoder, so each decoding step receives the same encoded representation while the decoder maintains its own hidden state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "val conf = new NeuralNetConfiguration.Builder()\n",
    "                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)\n",
    "                .iterations(1)\n",
    "                .seed(123)\n",
    "                .regularization(true)\n",
    "                .l2(0.001)\n",
    "                .weightInit(WeightInit.XAVIER)\n",
    "                .updater(new AdaDelta())\n",
    "                .inferenceWorkspaceMode(WorkspaceMode.SINGLE)\n",
    "                .trainingWorkspaceMode(WorkspaceMode.SINGLE)\n",
    "                .graphBuilder()\n",
    "                .addInputs(\"encoderInput\",\"decoderInput\")\n",
    "                .setInputTypes(InputType.recurrent(2), InputType.recurrent(3))\n",
    "                .addLayer(\"encoder\", new GravesLSTM.Builder().nOut(96).activation(Activation.TANH).build(), \"encoderInput\")\n",
    "                .addLayer(\"encoder2\", new GravesLSTM.Builder().nOut(48).activation(Activation.TANH).build(), \"encoder\")\n",
    "                .addVertex(\"laststep\", new LastTimeStepVertex(\"encoderInput\"), \"encoder2\")\n",
    "                .addVertex(\"dup\", new DuplicateToTimeSeriesVertex(\"decoderInput\"), \"laststep\")\n",
    "                .addLayer(\"decoder\", new GravesLSTM.Builder().nOut(48).activation(Activation.TANH).build(), \"decoderInput\", \"dup\")\n",
    "                .addLayer(\"decoder2\", new GravesLSTM.Builder().nOut(96).activation(Activation.TANH).build(), \"decoder\")\n",
    "                .addLayer(\"output\", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).activation(Activation.SIGMOID).nOut(3).build(), \"decoder2\")\n",
    "                .setOutputs(\"output\")\n",
    "                .build()\n",
    "    \n",
    "val net = new ComputationGraph(conf)\n",
    "net.setListeners(new ScoreIterationListener(1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Unsupervised training\n",
    "\n",
    "Now that the network configuration is set up and instantiated along with our iterators, training takes just a few lines of code.\n",
    "\n",
    "Earlier we attached a `ScoreIterationListener` to the model using the `setListeners()` method. Depending on the browser you use to run this notebook, you can open the debugger/inspector to view the listener output. The output goes to the console because the internals of Deeplearning4j log through SLF4J, and Zeppelin redirects that logging - a good thing, since it reduces clutter in notebooks.\n",
    "\n",
    "Because this network is an autoencoder, the score printed after each iteration is the reconstruction error (mean-squared error) on the current batch - that is the number to watch as training progresses. Classification metrics such as accuracy and precision from `evaluate()` do not apply to reconstruction; for supervised models, it is strongly recommended you learn advanced evaluation with ROC curves and how to read a confusion matrix."
   ]
  },
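  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the listener's score concrete: for this network, the score is the mean-squared reconstruction error between a batch and its reconstruction. The self-contained Scala sketch below (illustrative values only, no DL4J required) computes that quantity for a single flattened sequence:\n",
    "\n",
    "```scala\n",
    "// mean-squared error between an original sequence and its reconstruction\n",
    "def mse(original: Array[Double], reconstructed: Array[Double]): Double = {\n",
    "    require(original.length == reconstructed.length)\n",
    "    original.zip(reconstructed).map { case (o, r) => (o - r) * (o - r) }.sum / original.length\n",
    "}\n",
    "\n",
    "val orig  = Array(0.10, 0.40, 0.80)  // hypothetical normalized coordinates\n",
    "val recon = Array(0.12, 0.38, 0.81)  // hypothetical network output\n",
    "println(f\"reconstruction MSE = ${mse(orig, recon)}%.6f\")\n",
    "```"
   ]
  },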
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "// pass the training iterator to fit() to watch the network learn one epoch of training\n",
    "net.fit(trainIter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train for multiple epochs\n",
    "\n",
    "Deeplearning4j has a built-in `MultipleEpochsIterator` that automatically handles multiple epochs of training. If you want to handle per-epoch events yourself, you can either use an `EarlyStoppingGraphTrainer`, which listens for scoring events, or wrap `net.fit()` in a for-loop.\n",
    "\n",
    "Below, we manually create a for-loop since our iterator requires a more complex `MultiDataSet`. This is because our seq2seq autoencoder uses multiple inputs/outputs.\n",
    "\n",
    "The autoencoder here has been tuned to converge with an average reconstruction error of approximately 2% when trained for 35+ epochs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "// loop over epochs manually; each call to fit() performs one full pass over the iterator\n",
    "val numEpochs = 150\n",
    "\n",
    "(1 to numEpochs).foreach { i =>\n",
    "    net.fit(trainIter)\n",
    "    println(s\"Finished epoch $i\")\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save your model\n",
    "\n",
    "At this point, you've invested a lot of time and computation building your autoencoder. Saving it to disk and restoring it is quite simple."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "val modelFile = new File(cache, \"seq2seqautoencoder.zip\")\n",
    "// write to disk\n",
    "ModelSerializer.writeModel(net, modelFile, false)\n",
    "// restore from disk\n",
    "val restoredNet = ModelSerializer.restoreComputationGraph(modelFile)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compare reconstructed outputs\n",
    "\n",
    "Below we build a loop to visualize how well our autoencoder reconstructs the original sequences. After forwarding a single example, we score the reconstruction and then compare the original array to the reconstructed one. Note that we need some string formatting: printing a raw array directly yields only a reference to the array object in memory, not its contents."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "def arr2Dub(arr: INDArray): Array[Double] = arr.dup().data().asDouble()\n",
    "val format = new java.text.DecimalFormat(\"#.##\")\n",
    "\n",
    "testIter.reset()\n",
    "(0 to 10).foreach{ i =>\n",
    "    val mds = testIter.next(1)\n",
    "    val reconstructionError = net.score(mds)\n",
    "    val output = net.feedForward(mds.getFeatures(), false)\n",
    "    val feat = arr2Dub(mds.getFeatures(0))\n",
    "    val orig = feat.map(format.format(_)).mkString(\",\")\n",
    "    val recon = arr2Dub(output.get(\"output\")).map(format.format(_)).take(feat.size).mkString(\",\")\n",
    "    \n",
    "    println(s\"Reconstruction error for example $i is $reconstructionError\")\n",
    "    println(s\"Original array:        $orig\")\n",
    "    println(s\"Reconstructed array:   $recon\")\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transferring the parameters\n",
    "\n",
    "Now that the network has been trained, we will extract the encoder portion of it and use it to construct a new network that performs encoding only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "// use the GraphBuilder when your network is a ComputationGraph\n",
    "val encoder = new TransferLearning.GraphBuilder(net)\n",
    "    .setFeatureExtractor(\"laststep\")\n",
    "    .removeVertexAndConnections(\"decoder\")\n",
    "    .removeVertexAndConnections(\"decoder2\")\n",
    "    .removeVertexAndConnections(\"output\")\n",
    "    .removeVertexAndConnections(\"dup\")\n",
    "    .addLayer(\"output\", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), \"laststep\")\n",
    "    .setOutputs(\"output\")\n",
    "    .setInputs(\"encoderInput\")\n",
    "    .setInputTypes(InputType.recurrent(2))\n",
    "    .build()\n",
    "    \n",
    "// grab a single batch to test feed forward\n",
    "val ds = testIter.next(1)\n",
    "val embedding = encoder.feedForward(ds.getFeatures(0), false)\n",
    "val shape = embedding.get(\"output\").shape().mkString(\",\")\n",
    "val dsFeat = arr2Dub(ds.getFeatures(0))\n",
    "val dsOrig = dsFeat.map(format.format(_)).mkString(\",\")\n",
    "val rep = arr2Dub(embedding.get(\"output\")).map(format.format(_)).take(dsFeat.size).mkString(\",\")\n",
    "\n",
    "println(s\"Compressed shape:       $shape\")\n",
    "println(s\"Original array:        $dsOrig\")\n",
    "println(s\"Compressed array:      $rep\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Clustering the output\n",
    "\n",
    "Homestretch! We're now able to take the compressed representations of our trajectories and start to cluster them. As mentioned earlier, a clustering algorithm that does not require the number of clusters to be fixed in advance is preferable here.\n",
    "\n",
    "The [Smile Scala library](https://haifengl.github.io/smile/clustering.html#g-means) has a number of clustering methods already available and we'll be using it for grouping our trajectories."
   ]
  },
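  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, G-means works by repeatedly running k-means and statistically testing whether each cluster's points look Gaussian, splitting clusters that do not - so the number of clusters is discovered rather than supplied. The plain-Scala sketch below (made-up points, not the Smile API) illustrates the k-means assignment step that G-means builds on:\n",
    "\n",
    "```scala\n",
    "// squared Euclidean distance between two points\n",
    "def dist2(a: Array[Double], b: Array[Double]): Double =\n",
    "    a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum\n",
    "\n",
    "// k-means assignment step: each point joins its nearest center\n",
    "def assign(points: Array[Array[Double]], centers: Array[Array[Double]]): Array[Int] =\n",
    "    points.map(p => centers.zipWithIndex.minBy { case (c, _) => dist2(p, c) }._2)\n",
    "\n",
    "val pts = Array(Array(0.0, 0.0), Array(0.1, 0.2), Array(5.0, 5.0), Array(5.2, 4.9))\n",
    "val centers = Array(Array(0.0, 0.1), Array(5.1, 5.0))\n",
    "println(assign(pts, centers).mkString(\",\"))\n",
    "```"
   ]
  },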
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "// first we need to grab our representations\n",
    "// in a \"real world\" scenario we'd want something more elegant that preserves our MMSIs\n",
    "val dataset = scala.collection.mutable.ListBuffer.empty[Array[Double]]\n",
    "testIter.reset()\n",
    "\n",
    "while(testIter.hasNext()) {\n",
    "    val ds = testIter.next(1)\n",
    "    val rep = encoder.feedForward(ds.getFeatures(0), false)\n",
    "    dataset += rep.get(\"output\").dup.data.asDouble\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "import smile.clustering.GMeans\n",
    "\n",
    "val maxClusterNumber = 1000\n",
    "val gmeans = new GMeans(dataset.toArray, maxClusterNumber)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualizing the output\n",
    "\n",
    "Visualizing the clusters requires one extra step. Our seq2seq autoencoder produces representations with more than 2 or 3 dimensions, so you will need an algorithm such as t-SNE to reduce the dimensionality and generate \"coordinates\" that can be plotted. The pipeline below first runs t-SNE on the encoded representations to reduce each to two dimensions, then clusters the 2-D coordinates with G-means so the cluster labels can be plotted directly.\n",
    "\n",
    "#### Interpreting the result\n",
    "\n",
    "You may be thinking, \"do these clusters make sense?\" This is where further exploration is required. You'll need to go back to your clusters, identify the ships belonging to each one, and compare the ships within each cluster. If your encoder and clustering pipeline worked, you'll notice patterns such as:\n",
    "\n",
    "- ships crossing the English channel are grouped together\n",
    "- boats parked in marinas also cluster together\n",
    "- trans-atlantic ships also tend to cluster together"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "import smile.manifold.TSNE\n",
    "import smile.clustering.GMeans\n",
    "\n",
    "print(\"%table x\\ty\\tgroup\\tsize\") // this must come before any output\n",
    "\n",
    "val tsne = new TSNE(dataset.toArray, 2) // 2D plot\n",
    "val coordinates = tsne.getCoordinates()\n",
    "val gmeans = new GMeans(coordinates, 1000)\n",
    "\n",
    "(0 to coordinates.length-1).foreach{ i =>\n",
    "    val x = coordinates(i)(0)\n",
    "    val y = coordinates(i)(1)\n",
    "    val label = gmeans.getClusterLabel()(i)\n",
    "    print(s\"$x\\t$y\\t$label\\t1\\n\")\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### What's next?\n",
    "\n",
    "- Check out all of our tutorials available [on Github](https://github.com/deeplearning4j/dl4j-examples/tree/master/tutorials). Notebooks are numbered for easy following."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Spark 2.0.0 - Scala 2.11",
   "language": "scala",
   "name": "spark2-scala"
  },
  "language_info": {
   "codemirror_mode": "text/x-scala",
   "file_extension": ".scala",
   "mimetype": "text/x-scala",
   "name": "scala",
   "pygments_lexer": "scala",
   "version": "2.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
